from Network.network import Network
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import copy
from Network.network_utils import reduce_function, get_acti
from Network.General.Flat.mlp import MLPNetwork
from Network.General.Conv.conv import ConvNetwork
from Network.General.Factor.factor_utils import final_conv_args, final_mlp_args
from Network.General.Factor.factored import return_values
from Network.General.Factor.Pair.pair import merge_key_queries

def flatten_key_queries(key, query, mask, append_keys=True, append_broadcast_mask=0, append_mask=False):
    """Merge keys and queries, then flatten everything but the batch dimension.

    Shapes as noted by the original author (not checked here):
        key:   [batch, num_keys, key_dim]
        query: [batch, num_queries, query_dim]
        mask:  [batch, num_keys=1, num_queries]

    Args:
        key: key tensor; its first dimension is used as the batch size.
        query: query tensor, forwarded to ``merge_key_queries``.
        mask: mask tensor or None, forwarded to ``merge_key_queries``.
        append_keys: forwarded to ``merge_key_queries``.
        append_broadcast_mask: size of the broadcast mask appended to each
            query; forwarded to ``merge_key_queries``.
        append_mask: forwarded to ``merge_key_queries``. When keys are NOT
            appended and a mask is given, the flattened mask is also
            prepended to the flattened result here.

    Returns:
        A 2-D tensor with ``key.shape[0]`` rows.
    """
    combined = merge_key_queries(
        key,
        query,
        mask,
        append_keys=append_keys,
        append_broadcast_mask=append_broadcast_mask,
        append_mask=append_mask,
    )
    batch_size = key.shape[0]
    flat = combined.reshape(batch_size, -1)

    # Without keys appended, the mask is added once at the front of the
    # flattened vector instead.
    if not append_keys and mask is not None and append_mask:
        flat_mask = mask.reshape(batch_size, -1)
        flat = torch.cat([flat_mask, flat], dim=-1)
    return flat

class FlatPairNetwork(Network):
    """Pair network that uses one MLP over the flattened key/query inputs.

    Instead of a convolution over queries, all inputs (possibly with masks
    appended) are flattened into a single vector per batch element and
    mapped by an MLP to all outputs at once. An optional second MLP
    (``self.decode``) maps the embedding-sized output to the final output
    size.
    """

    def __init__(self, args):
        super().__init__(args)
        self.fp = args.factor
        factor_net = args.factor_net
        self.no_decode = factor_net.no_decode
        self.embed_dim = args.embed_dim
        self.append_keys = factor_net.append_keys
        self.append_mask = factor_net.append_mask
        self.append_broadcast_mask = factor_net.append_broadcast_mask

        # Pair networks assume keys/queries were already embedded upstream;
        # a positive embed_dim is the per-object embedded width.
        self.mlp_layers = []
        mlp_args = copy.deepcopy(args)

        num_objects = self.fp.num_objects
        embedded = self.embed_dim > 0

        # The mask has num_objects entries and is repeated once per object
        # when keys are appended, otherwise it appears a single time.
        mask_repeats = num_objects if self.append_keys else 1
        mask_width = int(self.append_mask) * num_objects * mask_repeats
        broadcast_width = self.append_broadcast_mask * num_objects

        # TODO: make dependent on key and query dims if using feature masks
        if embedded:
            query_width = self.embed_dim * num_objects
            key_width = int(self.append_keys) * self.embed_dim * num_objects
            mlp_args.num_inputs = query_width + mask_width + broadcast_width + key_width
        else:
            query_width = args.object_dim * num_objects
            key_width = int(self.append_keys) * args.factor.single_obj_dim * num_objects
            mlp_args.num_inputs = query_width + key_width + broadcast_width + mask_width

        # The MLP emits embeddings only when a decoder follows; otherwise it
        # emits the final per-object outputs directly.
        if embedded and not self.no_decode:
            mlp_args.num_outputs = self.embed_dim * num_objects
        else:
            mlp_args.num_outputs = args.output_dim * num_objects

        if embedded:
            mlp_args.activation_final = mlp_args.activation
        else:
            mlp_args.activation_final = args.activation_final

        self.mlp_args = mlp_args
        self.mlp_layer = MLPNetwork(mlp_args)
        layers = [self.mlp_layer]

        # Written back onto the caller's args before final_mlp_args(args) is
        # called below.
        if embedded:
            args.factor.final_embed_dim = self.embed_dim
        else:
            args.factor.final_embed_dim = args.factor.key_dim + args.factor.query_dim
        self.aggregate_final = args.aggregate_final

        # A decoder from embedding width to output width is built whenever
        # decoding is enabled and either aggregation is requested or the
        # inputs are embedded (does not work with a post-channel).
        wants_decoder = (not self.no_decode) and (args.aggregate_final or embedded)
        if wants_decoder:
            final_args = final_mlp_args(args)
            final_args.num_inputs = args.embed_dim * args.factor.num_objects
            final_args.num_outputs = args.output_dim * args.factor.num_objects
            self.decode = MLPNetwork(final_args)
            layers.append(self.decode)

        self.model = layers
        self.train()
        self.reset_network_parameters()

    def forward(self, key, query, mask, ret_settings):
        """Run the flattened MLP (and the decoder, if enabled).

        Assumes only a single key; see keypair for multi-key networks.

        Args:
            key: key tensor passed to ``flatten_key_queries``.
            query: query tensor passed to ``flatten_key_queries``.
            mask: mask passed to ``flatten_key_queries``.
            ret_settings: forwarded unchanged to ``return_values``.

        Returns:
            Whatever ``return_values`` produces from the final output, the
            ``(key, query)`` pair, the MLP output as embeddings, and a
            ``None`` reduction.
        """
        # flattened to 2-D: [batch, -1]
        flat_inputs = flatten_key_queries(
            key,
            query,
            mask,
            append_keys=self.append_keys,
            append_mask=self.append_mask,
            append_broadcast_mask=self.append_broadcast_mask,
        )
        embeddings = self.mlp_layer(flat_inputs)
        output = embeddings
        if self.embed_dim > 0 and not self.no_decode:
            output = self.decode(embeddings)
        return return_values(ret_settings, output, (key, query), embeddings, None)